{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# An Introduction to Models in Pyro\n",
    "\n",
    "The basic unit of probabilistic programs is the _stochastic function_. \n",
    "This is an arbitrary Python callable that combines two ingredients:\n",
    "\n",
    "- deterministic Python code; and\n",
    "- primitive stochastic functions that call a random number generator.\n",
    "\n",
    "Concretely, a stochastic function can be any Python object with a `__call__()` method, like a function, a method, or a PyTorch `nn.Module`.\n",
    "\n",
    "Throughout the tutorials and documentation, we will often call stochastic functions *models*, since stochastic functions can be used to represent simplified or abstract descriptions of a process by which data are generated.  Expressing models as stochastic functions means that models can be composed, reused, imported, and serialized just like regular Python callables. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pyro\n",
    "\n",
    "pyro.set_rng_seed(101)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Primitive Stochastic Functions\n",
    "\n",
    "Primitive stochastic functions, or distributions, are an important class of stochastic functions for which we can explicitly compute the probability of the outputs given the inputs.  As of PyTorch 0.4 and Pyro 0.2, Pyro uses PyTorch's [distribution library](http://pytorch.org/docs/master/distributions.html). You can also create custom distributions using [transforms](http://pytorch.org/docs/master/distributions.html#module-torch.distributions.transforms).\n",
    "\n",
    "Using primitive stochastic functions is easy. For example, to draw a sample `x` from the unit normal distribution $\\mathcal{N}(0,1)$ we do the following:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sample tensor(-1.3905)\n",
      "log prob tensor(-1.8857)\n"
     ]
    }
   ],
   "source": [
    "loc = 0.   # mean zero\n",
    "scale = 1. # unit standard deviation (and hence unit variance)\n",
    "normal = torch.distributions.Normal(loc, scale) # create a normal distribution object\n",
    "x = normal.rsample() # draw a sample from N(0,1)\n",
    "print(\"sample\", x)\n",
    "print(\"log prob\", normal.log_prob(x)) # score the sample from N(0,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, `torch.distributions.Normal` is a subclass of the `Distribution` class; its instances take parameters and provide sample and score methods. Pyro's distribution library `pyro.distributions` is a thin wrapper around `torch.distributions` because we want to make use of PyTorch's fast tensor math and autograd capabilities during inference."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A Simple Model\n",
    "\n",
    "All probabilistic programs are built up by composing primitive stochastic functions and deterministic computation. Since we're ultimately interested in probabilistic programming because we want to model things in the real world, let's start with a model of something concrete. \n",
    "\n",
    "Let's suppose we have a bunch of data with daily mean temperatures and cloud cover. We want to reason about how temperature interacts with whether it was sunny or cloudy. A simple stochastic function that describes how that data might have been generated is given by:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def weather():\n",
    "    cloudy = torch.distributions.Bernoulli(0.3).sample()\n",
    "    cloudy = 'cloudy' if cloudy.item() == 1.0 else 'sunny'\n",
    "    mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]\n",
    "    scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]\n",
    "    temp = torch.distributions.Normal(mean_temp, scale_temp).rsample()\n",
    "    return cloudy, temp.item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's go through this line-by-line. First, in line 2 we define a binary random variable 'cloudy', which is given by a draw from the Bernoulli distribution with a parameter of `0.3`. Since the Bernoulli distribution returns `0` or `1`, in line 3 we convert the value `cloudy` to a string so that return values of `weather` are easier to parse. So according to this model 30% of the time it's cloudy and 70% of the time it's sunny.\n",
    "\n",
    "In lines 4-5 we define the parameters we're going to use to sample the temperature in line 6. These parameters depend on the particular value of `cloudy` we sampled in line 2. For example, the mean temperature is 55 degrees (Fahrenheit) on cloudy days and 75 degrees on sunny days. Finally we return the two values `cloudy` and `temp` in line 7.\n",
    "\n",
    "However, `weather` is entirely independent of Pyro - it only calls PyTorch. We need to turn it into a Pyro program if we want to use this model for anything other than sampling fake data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The `pyro.sample` Primitive\n",
    "\n",
    "To turn `weather` into a Pyro program, we'll replace the `torch.distributions` objects with their `pyro.distributions` counterparts and the `.sample()` and `.rsample()` calls with calls to `pyro.sample`, one of the core language primitives in Pyro. Using `pyro.sample` is as simple as calling a primitive stochastic function, with one important difference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(-0.8152)\n"
     ]
    }
   ],
   "source": [
    "x = pyro.sample(\"my_sample\", pyro.distributions.Normal(loc, scale))\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just like a direct call to `torch.distributions.Normal().rsample()`, this returns a sample from the unit normal distribution. The crucial difference is that this sample is _named_. Pyro's backend uses these names to uniquely identify sample statements and _change their behavior at runtime_ depending on how the enclosing stochastic function is being used. As we will see, this is how Pyro can implement the various manipulations that underlie inference algorithms."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we've introduced `pyro.sample` and `pyro.distributions` we can rewrite our simple model as a Pyro program:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('cloudy', 64.5440444946289)\n",
      "('sunny', 94.37557983398438)\n",
      "('sunny', 72.5186767578125)\n"
     ]
    }
   ],
   "source": [
    "def weather():\n",
    "    cloudy = pyro.sample('cloudy', pyro.distributions.Bernoulli(0.3))\n",
    "    cloudy = 'cloudy' if cloudy.item() == 1.0 else 'sunny'\n",
    "    mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]\n",
    "    scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]\n",
    "    temp = pyro.sample('temp', pyro.distributions.Normal(mean_temp, scale_temp))\n",
    "    return cloudy, temp.item()\n",
    "\n",
    "for _ in range(3):\n",
    "    print(weather())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Procedurally, `weather()` is still a non-deterministic Python callable that returns two random samples. Because the randomness is now invoked with `pyro.sample`, however, it is much more than that. In particular `weather()` specifies a joint probability distribution over two named random variables: `cloudy` and `temp`. As such, it defines a probabilistic model that we can reason about using the techniques of probability theory. For example we might ask: if I observe a temperature of 70 degrees, how likely is it to be cloudy? How to formulate and answer these kinds of questions will be the subject of the next tutorial."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Universality: Stochastic Recursion, Higher-order Stochastic Functions, and Random Control Flow\n",
    "\n",
    "We've now seen how to define a simple model. Building off of it is easy. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ice_cream_sales():\n",
    "    cloudy, temp = weather()\n",
    "    expected_sales = 200. if cloudy == 'sunny' and temp > 80.0 else 50.\n",
    "    ice_cream = pyro.sample('ice_cream', pyro.distributions.Normal(expected_sales, 10.0))\n",
    "    return ice_cream"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This kind of modularity, familiar to any programmer, is obviously very powerful. But is it powerful enough to encompass all the different kinds of models we'd like to express?\n",
    "\n",
    "It turns out that because Pyro is embedded in Python, stochastic functions can contain arbitrarily complex deterministic Python, and randomness can freely affect control flow. For example, we can construct recursive functions that terminate their recursion nondeterministically, provided we take care to pass `pyro.sample` a unique sample name each time it's called. The following function defines a geometric distribution that counts the number of failures until the first success:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "def geometric(p, t=None):\n",
    "    if t is None:\n",
    "        t = 0\n",
    "    x = pyro.sample(\"x_{}\".format(t), pyro.distributions.Bernoulli(p))\n",
    "    if x.item() == 1:\n",
    "        return 0\n",
    "    else:\n",
    "        return 1 + geometric(p, t + 1)\n",
    "\n",
    "print(geometric(0.5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the names `x_0`, `x_1`, etc., in `geometric()` are generated dynamically and that different executions can have different numbers of named random variables. \n",
    "\n",
    "We are also free to define stochastic functions that accept as input or produce as output other stochastic functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(2.1493)\n"
     ]
    }
   ],
   "source": [
    "def normal_product(loc, scale):\n",
    "    z1 = pyro.sample(\"z1\", pyro.distributions.Normal(loc, scale))\n",
    "    z2 = pyro.sample(\"z2\", pyro.distributions.Normal(loc, scale))\n",
    "    y = z1 * z2\n",
    "    return y\n",
    "\n",
    "def make_normal_normal():\n",
    "    mu_latent = pyro.sample(\"mu_latent\", pyro.distributions.Normal(0, 1))\n",
    "    fn = lambda scale: normal_product(mu_latent, scale)\n",
    "    return fn\n",
    "\n",
    "print(make_normal_normal()(1.))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
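    "Here `make_normal_normal()` is a stochastic function that takes no arguments and returns another stochastic function of one argument, `scale`. Calling `make_normal_normal()` and then calling the function it returns generates three named random variables in total: `mu_latent`, `z1`, and `z2`.\n",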
    "\n",
    "The fact that Pyro supports arbitrary Python code like this (iteration, recursion, higher-order functions, etc.) in conjunction with random control flow means that Pyro stochastic functions are _universal_, i.e. they can be used to represent any computable probability distribution. As we will see in subsequent tutorials, this is incredibly powerful.\n",
    "\n",
    "It is worth emphasizing that this is one reason why Pyro is built on top of PyTorch: dynamic computational graphs are an important ingredient in allowing for universal models that can benefit from GPU-accelerated tensor math."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next Steps\n",
    "\n",
    "We've shown how we can use stochastic functions and primitive distributions to represent models in Pyro. In order to learn models from data and reason about them we need to be able to do inference. This is the subject of the [next tutorial](intro_part_ii.ipynb)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
